iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 11
0
自我挑戰組

Tensorflow.js初學筆記系列 第 11

Day10 TensorFlow.js: MNIST手寫數字辨識2

  • 分享至 

  • xImage
  •  

昨天大致上,對CNN有基本的了解
今天就照第8天的流程跑一次

首先昨天就有的新增模型

function createModel() {
    const model = tf.sequential();
    const IMAGE_WIDTH = 28;
    const IMAGE_HEIGHT = 28;
    const IMAGE_CHANNELS = 1;
    model.add(tf.layers.conv2d({
        inputShape: [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS],
        kernelSize: 5,
        filters: 8,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));
    model.add(tf.layers.maxPooling2d({
        poolSize: [2, 2],
        strides: [2, 2]
    }));
    model.add(tf.layers.conv2d({
        kernelSize: 5,
        filters: 16,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));
    model.add(tf.layers.maxPooling2d({
        poolSize: [2, 2],
        strides: [2, 2]
    }));
    model.add(tf.layers.flatten());
    model.add(tf.layers.dense({
        units: 10,
        kernelInitializer: 'varianceScaling',
        activation: 'softmax'
    }));
    model.compile({
        optimizer: tf.train.adam(),
        loss: 'categoricalCrossentropy',
        metrics: ['accuracy'],
    });
    return model;
}

再來拿資料的部分直接引用MnistData.js去載入資料,老實說這個程式碼我看了很久也想不出來要怎麼改

async function getData(){
    const mnist_data = new MnistData();
    await mnist_data.load();
    return mnist_data;
}

再來就是轉換成tensor
因為裡面MnistData.js裡面包裝得差不多了
所以已只要調裡面的方法來用就好

function convertToTensor(mnist_data, method, size) {
    return tf.tidy(() => {
        const this_batch = mnist_data[method](size);
        return {
            inputs: this_batch.xs.reshape([size, 28, 28, 1]),
            labels: this_batch.labels
        }
    });
}

然後訓練模型

async function trainModel(model, t_data,v_data) {
    //每次訓練的樣本數
    const batchSize = 500;
    //訓練多少代
    const epochs =10;
    return await model.fit(t_data.inputs, t_data.labels, {
        batchSize,
        epochs,
        shuffle: true,
        validationData: [v_data.inputs, v_data.labels],
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training Performance' },
            ['loss', 'val_loss', 'acc', 'val_acc'],
            { height: 200, callbacks: ['onEpochEnd'] }
        )
    });
}

然後就執行

async function runTensorFlow(){
    const model=createModel();
    const mnist_data= await getData();
    const traindata=convertToTensor(mnist_data,"nextTrainBatch",5000);
    const validationdata=convertToTensor(mnist_data,"nextTestBatch",1000);
    await trainModel(model,traindata,validationdata);
}
document.addEventListener('DOMContentLoaded', runTensorFlow);

這時候就會看到訓練的過程


上一篇
Day9 TensorFlow.js: MNIST手寫數字辨識
下一篇
Day11 TensorFlow 有關tensor本身的操作
系列文
Tensorflow.js初學筆記27
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言